import numpy as np

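# Generated with Sollya's fpminimax; keys are the polynomial degree n.
# c_0..c_{n-1} are fp32 (24-bit significand), c_n is bf16 (8-bit significand).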
SIGMOID_POSITIVE_PLOY_COEFFCIENTS = {
    # sollya command: fpminimax(1/(1+exp(-x)), 3, [|24,24,24,8|], [0.0; 7.0], floating, absolute);
    3: [
        0.500168383121490478515625, 0.280329048633575439453125,
        -5.28283752501010894775390625e-2, 3.2958984375e-3
    ],
    # sollya command: fpminimax(1/(1+exp(-x)), 4, [|24,24,24,24,8|], [0; 7.0], floating, absolute);
    4: [
        0.4957107007503509521484375, 0.2901675403118133544921875,
        -5.83785437047481536865234375e-2, 4.4221165589988231658935546875e-3,
        -7.4863433837890625e-5
    ],
    # sollya command: fpminimax(1/(1+exp(-x)), 5, [|24,24,24,24,24,8|], [0.0; 7.0], floating, absolute);
    5: [
        0.497888505458831787109375, 0.271907985210418701171875,
        -3.463833034038543701171875e-2, -6.04884326457977294921875e-3,
        1.767211011610925197601318359375e-3, -1.1157989501953125e-4
    ],
    # # sollya command: fpminimax(1/(1+exp(-x)), 6, [|24,24,24,24,24,24,8|], [0.0; 7.0], floating, absolute);
    # 6: [0.4994693100452423095703125, 0.2547923624515533447265625, -4.6910843811929225921630859375e-3, -2.507667057216167449951171875e-2, 7.193140685558319091796875e-3, -8.205013000406324863433837890625e-4, 3.45706939697265625e-5],
    # # sollya command: fpminimax(1/(1+exp(-x)), 7, [|24,24,24,24,24,24,24,8|], [0.0; 7.0], floating, absolute);
    # 7: [0.500115573406219482421875, 0.24712221324443817138671875, 1.17001868784427642822265625e-2, -3.86395789682865142822265625e-2, 1.26150362193584442138671875e-2, -1.94144877605140209197998046875e-3, 1.49970306665636599063873291015625e-4, -4.6789646148681640625e-6],
    # # sollya command: fpminimax(1/(1+exp(-x)), 8, [|24,24,24,24,24,24,24,24,8|], [0.0; 7.0], floating, absolute);
    # 8: [0.500094711780548095703125, 0.24758051335811614990234375, 1.0084983892738819122314453125e-2, -3.65161038935184478759765625e-2, 1.12692527472972869873046875e-2, -1.485520391725003719329833984375e-3, 6.552544073201715946197509765625e-5, 3.37545134243555366992950439453125e-6, -3.091990947723388671875e-7],
    # # sollya command: fpminimax(1/(1+exp(-x)), 9, [|24,24,24,24,24,24,24,24,24,8|], [0.0; 7.0], floating, absolute);
    # 9: [0.500036716461181640625, 0.24901355803012847900390625, 4.282380454242229461669921875e-3, -2.753300964832305908203125e-2, 4.350929521024227142333984375e-3, 1.4890762977302074432373046875e-3, -6.818011752329766750335693359375e-4, 1.12315974547527730464935302734375e-4, -8.8410370153724215924739837646484375e-6, 2.7753412723541259765625e-7],
    # # sollya command: fpminimax(1/(1+exp(-x)), 10, [|24,24,24,24,24,24,24,24,24,24,8|], [0.0; 7.0], floating, absolute);
    # 10: [0.500007152557373046875, 0.2498905658721923828125, 7.46184741728939116001129150390625e-5, -1.971541158854961395263671875e-2, -3.019575960934162139892578125e-3, 5.4769436828792095184326171875e-3, -1.99403869919478893280029296875e-3, 3.792881616391241550445556640625e-4, -4.163103949395008385181427001953125e-5, 2.505129259589011780917644500732421875e-6, -6.4261257648468017578125e-8],
}

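# Generated with Sollya's fpminimax; the outer key is exp_table_size and the
# inner key is the polynomial degree n. c_0..c_{n-1} are fp32, c_n is bf16.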
EXP_POLY_COEFFCIENTS = {
    1: {
        2: [0.99975693225860595703125, 1.015083789825439453125, 0.5078125],
        3: [
            0.999924480915069580078125, 0.999958038330078125,
            0.505023539066314697265625, 0.16796875
        ],
        4: [
            1.0000002384185791015625, 0.999962270259857177734375,
            0.4999794065952301025390625, 0.1679216921329498291015625,
            4.19921875e-2
        ],
        5: [
            1.00000011920928955078125, 1, 0.49998724460601806640625,
            0.16666592657566070556640625, 4.1927136480808258056640625e-2,
            8.36181640625e-3
        ]
    },
    2: {
        2: [1.0000140666961669921875, 1.0037591457366943359375, 0.5],
        3: [
            0.999995291233062744140625, 0.9999973773956298828125,
            0.501252651214599609375, 0.1669921875
        ],
        4: [
            1, 0.999997675418853759765625, 0.499998867511749267578125,
            0.166978657245635986328125, 4.1748046875e-2
        ],
        5: [
            1, 1, 0.499999463558197021484375, 0.1666660606861114501953125,
            4.172463715076446533203125e-2, 8.36181640625e-3
        ]
    },
    4: {
        2: [1.00000083446502685546875, 1.00093877315521240234375, 0.5],
        3: [
            0.999999701976776123046875, 0.99999845027923583984375,
            0.50031363964080810546875, 0.1669921875
        ],
        4: [
            1, 0.99999988079071044921875, 0.4999995529651641845703125,
            0.16674046218395233154296875, 4.1748046875e-2
        ],
        5: [
            1, 1, 0.4999999701976776123046875, 0.16666649281978607177734375,
            4.168058931827545166015625e-2, 8.36181640625e-3
        ]
    },
    8: {
        2: [1, 1.0002346038818359375, 0.5],
        3: [
            1, 0.999999582767486572265625, 0.500065147876739501953125,
            0.1669921875
        ],
        4: [
            1, 1, 0.49999988079071044921875, 0.16668035089969635009765625,
            4.1748046875e-2
        ],
        5: [
            1, 1, 0.5, 0.16666662693023681640625, 4.166901111602783203125e-2,
            8.36181640625e-3
        ]
    },
    16: {
        2: [1, 1.000058650970458984375, 0.5],
        3: [
            1, 0.99999988079071044921875, 0.500016272068023681640625,
            0.1669921875
        ],
        4: [
            1, 1, 0.4999999701976776123046875, 0.1666700839996337890625,
            4.1748046875e-2
        ],
        5: [
            1, 1, 0.5, 0.1666666567325592041015625,
            4.16672527790069580078125e-2, 8.36181640625e-3
        ]
    },
    32: {
        2: [1, 1.00001466274261474609375, 0.5],
        3: [1, 1, 0.5000040531158447265625, 0.1669921875],
        4: [
            1, 1, 0.4999999701976776123046875, 0.166667520999908447265625,
            4.19921875e-2
        ],
        5: [
            1, 1, 0.5, 0.16666667163372039794921875,
            4.1666813194751739501953125e-2, 8.30078125e-3
        ]
    },
    64: {
        2: [1, 1.00000369548797607421875, 0.5],
        3: [1, 1, 0.500001013278961181640625, 0.1669921875],
        4: [1, 1, 0.5, 0.16666688024997711181640625, 4.1748046875e-2],
        5: [
            1, 1, 0.5, 0.16666667163372039794921875,
            4.16667051613330841064453125e-2, 8.11767578125e-3
        ]
    },
    128: {
        2: [1, 1.00000095367431640625, 0.5],
        3: [1, 1, 0.5000002384185791015625, 0.1669921875],
        4: [1, 1, 0.5, 0.1666667163372039794921875, 4.1748046875e-2],
        5: [
            1, 1, 0.5, 0.16666667163372039794921875,
            4.16666753590106964111328125e-2, 7.568359375e-3
        ]
    },
    256: {
        2: [1, 1.0000002384185791015625, 0.5],
        3: [1, 1, 0.500000059604644775390625, 0.1669921875],
        4: [1, 1, 0.5, 0.166666686534881591796875, 4.1748046875e-2],
        5: [
            1, 1, 0.5, 0.16666667163372039794921875,
            4.16666679084300994873046875e-2, 5.31005859375e-3
        ]
    },
    512: {
        2: [1, 1, 0.5],
        3: [1, 1, 0.5, 0.1669921875],
        4: [1, 1, 0.5, 0.16666667163372039794921875, 4.1748046875e-2],
        # sollya command: fpminimax(exp(x), 5, [|24, 24, 24, 24, 24, 8|], [-0.5*log(2.0)/512, 0.5*log(2.0)/512], floating, absolute);
        5: [
            1, 1, 0.5, 0.16666667163372039794921875,
            4.16666679084300994873046875e-2, -3.7078857421875e-3
        ]
    }
}

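# Generated with Sollya's fpminimax for the tanh term of the GELU approximation;
# keys are the polynomial degree n. c_0..c_{n-1} are fp32, c_n is bf16.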
GLUE_TANH_POSITIVE_POLY_COEFFCIENTS = {
    # sollya command: fpminimax(tanh(0.7978845608028654 * (x+0.044715 * x^3.0)), 2, [|24,24,8|], [0; 2.5], floating, absolute);
    2: [8.134123869240283966064453125e-3, 0.846239566802978515625, -0.18359375],
    # sollya command: fpminimax(tanh(0.7978845608028654 * (x+0.044715 * x^3.0)), 3, [|24,24,24,8|], [0; 2.5], floating, absolute);
    3: [
        -7.6428051106631755828857421875e-3, 0.906318247318267822265625,
        -0.23573850095272064208984375, 1.251220703125e-2
    ],
    # sollya command: fpminimax(tanh(0.7978845608028654 * (x+0.044715 * x^3.0)), 4, [|24,24,24,24,8|], [0; 2.5], floating, absolute);
    4: [
        -1.59433088265359401702880859375e-3, 0.82405579090118408203125,
        -6.302885711193084716796875e-2, -0.10210405290126800537109375,
        2.35595703125e-2
    ],
    # sollya command: fpminimax(tanh(0.7978845608028654 * (x+0.044715 * x^3.0)), 5, [|24,24,24,24,24,8|], [0; 2.5], floating, absolute);
    5: [
        1.89869213500060141086578369140625e-4, 0.7912051677703857421875,
        3.79874520003795623779296875e-2, -0.21261592209339141845703125,
        7.3306463658809661865234375e-2, -7.87353515625e-3
    ],
}

# Generated with Sollya's fpminimax for tanh(x) on [0, 4]; keys are the
# polynomial degree n. c_0..c_{n-1} are fp32, c_n is bf16.
TANH_POSITIVE_POLY_COEFFCIENTS = {
    # sollya command: fpminimax(tanh(x), 3, [|24,24,24,8|], [0; 4.0], floating, absolute);
    3: [
        1.4900849200785160064697265625e-2, 1.06381762027740478515625,
        -0.377149283885955810546875, 4.345703125e-2
    ],
    # sollya command: fpminimax(tanh(x), 4, [|24,24,24,24,8|], [0; 4.0], floating, absolute);
    4: [
        -9.55459661781787872314453125e-3, 1.1741485595703125,
        -0.4937419593334197998046875, 8.647145330905914306640625e-2,
        -5.126953125e-3
    ],
    # sollya command: fpminimax(tanh(x), 5, [|24,24,24,24,24,8|], [0; 4.0], floating, absolute);
    5: [
        -6.0254074633121490478515625e-3, 1.11906754970550537109375,
        -0.3621589839458465576171875, -1.8923975527286529541015625e-2,
        2.824944444000720977783203125e-2, -3.6163330078125e-3
    ],
    # sollya command: fpminimax(tanh(x), 6, [|24,24,24,24,24,24,8|], [0; 4.0], floating, absolute);
    6: [
        -2.11274274624884128570556640625e-3, 1.04248869419097900390625,
        -0.12120990455150604248046875, -0.292342960834503173828125,
        0.16680581867694854736328125, -3.56756933033466339111328125e-2,
        2.7618408203125e-3
    ],
    # sollya command: fpminimax(tanh(x), 7, [|24,24,24,24,24,24,24,8|], [0; 4.0], floating, absolute);
    7: [
        -2.822799724526703357696533203125e-4, 0.996061742305755615234375,
        6.5993212163448333740234375e-2, -0.576497375965118408203125,
        0.37244832515716552734375, -0.112049378454685211181640625,
        1.68143846094608306884765625e-2, -1.01470947265625e-3
    ],
    # 8: [],
    # 9: [],
    # 10: [],
}

def get_exp_poly_coeffcients(exp_table_size=1, degree=2):
  """Return the exp(x) polynomial coefficients for a table size and degree.

  Looks up ``EXP_POLY_COEFFCIENTS[exp_table_size][degree]`` and converts the
  stored list to a float32 array.

  Args:
    exp_table_size: Outer table key. Must be one of the sizes that has an
      entry in ``EXP_POLY_COEFFCIENTS``.
    degree: Polynomial degree. Must be a key of the sub-table selected by
      ``exp_table_size``.

  Returns:
    A new 1-D ``np.float32`` array with the stored coefficients in table
    order (``degree + 1`` values, c_0 first).

  Raises:
    KeyError: If ``exp_table_size`` or ``degree`` has no entry in the table.
  """
  # Validate against the table itself instead of a numeric range: a range
  # check lets through sizes such as 3 that have no entry, and ``assert``
  # statements are removed under ``python -O``.  KeyError is used because it
  # is what an unsupported lookup already raised; the message now lists the
  # supported values.
  by_degree = EXP_POLY_COEFFCIENTS.get(exp_table_size)
  if by_degree is None:
    raise KeyError(
        f'exp_table_size must be one of: {list(EXP_POLY_COEFFCIENTS)}')
  if degree not in by_degree:
    raise KeyError(f'degree must be one of: {list(by_degree)}')

  return np.array(by_degree[degree], dtype=np.float32)

def get_sigmoid_positive_ploy_coeffcients(degree=4):
  """Return the sigmoid polynomial coefficients for the given degree.

  Args:
    degree: Polynomial degree; must be a key of
      ``SIGMOID_POSITIVE_PLOY_COEFFCIENTS``.

  Returns:
    A new 1-D ``np.float32`` array with the stored coefficients in table
    order.

  Raises:
    NotImplementedError: If no coefficients are stored for ``degree``.
  """
  supported_degrees = list(SIGMOID_POSITIVE_PLOY_COEFFCIENTS)
  if degree not in supported_degrees:
    raise NotImplementedError(f'degree must in: {supported_degrees}')
  return np.array(SIGMOID_POSITIVE_PLOY_COEFFCIENTS[degree], dtype=np.float32)

def get_gelu_tanh_poly_coeffcients(degree=2):
  """Return the GELU-tanh polynomial coefficients for the given degree.

  Args:
    degree: Polynomial degree; must be a key of
      ``GLUE_TANH_POSITIVE_POLY_COEFFCIENTS``.

  Returns:
    A new 1-D ``np.float32`` array with the stored coefficients in table
    order.

  Raises:
    KeyError: If no coefficients are stored for ``degree``.
  """
  stored_coeffs = GLUE_TANH_POSITIVE_POLY_COEFFCIENTS[degree]
  return np.array(stored_coeffs, dtype=np.float32)

def get_tanh_positive_poly_coeffcients(degree=4):
  """Return the tanh polynomial coefficients for the given degree.

  Args:
    degree: Polynomial degree; must be a key of
      ``TANH_POSITIVE_POLY_COEFFCIENTS``.

  Returns:
    A new 1-D ``np.float32`` array with the stored coefficients in table
    order.

  Raises:
    KeyError: If no coefficients are stored for ``degree``.
  """
  stored_coeffs = TANH_POSITIVE_POLY_COEFFCIENTS[degree]
  return np.array(stored_coeffs, dtype=np.float32)
